-- pipeline.linear_regression: trains a MADlib linear regression model
-- (madlib.linregr_train) on source_table, writes per-row predictions
-- (madlib.linregr_predict) into out_table_result, and returns a JSON text
-- document with a status code, an error message and the created tables' columns.
--   ground_truth  : dependent-variable column
--   feature_cols  : comma-separated feature columns (an intercept is added)
--   grouping_cols : comma-separated grouping columns, or the text 'NULL'
CREATE OR REPLACE FUNCTION "pipeline"."linear_regression"("source_table" varchar, "out_table_model" varchar, "out_table_result" varchar, "ground_truth" varchar, "feature_cols" varchar, "grouping_cols" varchar)
  RETURNS "pg_catalog"."text" AS $BODY$
    import json

    def getTableMetaInfo(table_name):
        """Return the columns of ``table_name`` as ``{column_name: data_type}``.

        ``table_name`` may be bare (``tbl``) or schema-qualified
        (``schema.tbl``). The last dotted component is the table name and the
        one before it, if present, restricts the lookup to that schema. Columns
        are read from ``information_schema.columns`` and returned in
        ``ordinal_position`` order (as an OrderedDict).

        NOTE(review): names are compared verbatim against the catalog, so
        quoted / mixed-case identifiers must be passed exactly as stored.
        NOTE(review): a bare name is matched in every schema, so same-named
        tables in different schemas would still be merged.
        """
        from collections import OrderedDict

        parts = table_name.split(".")
        relname = parts[-1]
        if len(parts) > 1:
            # Schema-qualified: only look at that schema.
            plan = plpy.prepare(
                "select column_name, data_type from information_schema.columns"
                " where table_schema = $1 and table_name = $2"
                " order by ordinal_position",
                ["text", "text"])
            rs = plpy.execute(plan, [parts[-2], relname])
        else:
            plan = plpy.prepare(
                "select column_name, data_type from information_schema.columns"
                " where table_name = $1"
                " order by ordinal_position",
                ["text"])
            rs = plpy.execute(plan, [relname])
        meta = OrderedDict()
        for row in rs:
            meta[row['column_name']] = row['data_type']
        return meta
    
    def linearRegressionTrain(sourceTable, modelTable, groundTruth, featureCols, groupingCols):
        """Fit a linear regression with ``madlib.linregr_train``.

        :param sourceTable: input table name.
        :param modelTable: name of the model table to create; on failure both
            it and ``<modelTable>_summary`` are dropped.
        :param groundTruth: dependent-variable column/expression.
        :param featureCols: list of feature columns; an intercept term ``1`` is
            prepended to form the independent-variable array.
        :param groupingCols: list of grouping columns, or None/empty for a
            single global model (passed to MADlib as SQL NULL).
        :returns: ``(True, "success")`` or ``(False, error_message)``; the error
            message has the generated SQL appended after `` sql=``.
        """
        # Independent variables: intercept first, then the features.
        indepVars = "ARRAY[{}]".format(",".join(["1"] + list(featureCols)))
        grouping = ",".join(groupingCols) if groupingCols else None
        # MADlib takes every argument as text; quote each value so a quote
        # character inside a name cannot break out of its literal.
        sqlStr = "SELECT * FROM madlib.linregr_train({}, {}, {}, {}, {})".format(
            plpy.quote_nullable(sourceTable),
            plpy.quote_nullable(modelTable),
            plpy.quote_nullable(groundTruth),
            plpy.quote_nullable(indepVars),
            plpy.quote_nullable(grouping))
        try:
            plpy.execute(sqlStr)
        except Exception as e:
            # Remove whatever a partial run may have left behind.
            # NOTE(review): this also drops a pre-existing modelTable if the
            # failure was caused by that table already existing.
            plpy.execute("DROP TABLE IF EXISTS {}".format(modelTable))
            plpy.execute("DROP TABLE IF EXISTS {}_summary".format(modelTable))
            return False, str(e) + " sql=" + sqlStr
        return True, "success"
        
    def linearRegressionPredict(sourceTable, modelTable, resultTable, featureCols, groupingCols=None):
        """Materialise per-row predictions of a trained MADlib linear model.

        Creates ``resultTable`` with every column of ``sourceTable`` plus a
        ``predict`` column computed by ``madlib.linregr_predict`` from the
        model table's ``coef`` column.

        :param sourceTable: table to score.
        :param modelTable: model table produced by ``linregr_train``.
        :param resultTable: table to create.
        :param featureCols: feature columns in training order (an intercept
            term ``1`` is prepended).
        :param groupingCols: optional grouping columns used at training time.
            When given, each source row is scored only by its own group's model
            row. Otherwise the model table is cross-joined (expected to hold a
            single row).
        :returns: ``(True, "success")`` or ``(False, error_message)``; on
            failure both ``modelTable`` and ``resultTable`` are dropped.
        """
        indepVars = "ARRAY[{}]".format(", ".join(["1"] + list(featureCols)))
        if groupingCols:
            # A grouped model has one row per group. A plain cross join would
            # duplicate every source row once per group, each copy with the
            # wrong coefficients. IS NOT DISTINCT FROM also matches NULL keys.
            # NOTE(review): assumes grouping entries are plain column names
            # that exist in both the source and the model table.
            joinCond = " AND ".join(
                "a.{0} IS NOT DISTINCT FROM m.{0}".format(col) for col in groupingCols)
            fromClause = "{} a JOIN {} m ON {}".format(sourceTable, modelTable, joinCond)
        else:
            fromClause = "{} a CROSS JOIN {} m".format(sourceTable, modelTable)
        sqlStr = "CREATE TABLE {} AS SELECT a.*, madlib.linregr_predict({}, m.coef) as predict FROM {}".format(
            resultTable, indepVars, fromClause)
        try:
            plpy.execute(sqlStr)
        except Exception as e:
            # NOTE(review): this also drops a pre-existing resultTable if the
            # CREATE failed because that table already existed.
            plpy.execute("DROP TABLE IF EXISTS {}".format(modelTable))
            plpy.execute("DROP TABLE IF EXISTS {}".format(resultTable))
            return False, str(e) + " sql=" + sqlStr
        return True, "success"

    def logError(msg, status, result):
        """Record ``status``/``msg`` in the ``result`` dict and return it as JSON text."""
        result["status"] = status
        result["error_msg"] = msg
        return json.dumps(result)

    def begin(source_table, out_table_model, out_table_result, ground_truth, feature_cols, grouping_cols):
        """Run the train -> predict pipeline and describe the outcome as JSON.

        :param feature_cols: comma-separated feature columns (at least one).
        :param grouping_cols: comma-separated grouping columns. SQL NULL, an
            empty string or the text ``NULL`` mean "no grouping".
        :returns: JSON text. On success ``status`` is 0 and
            ``result.output_params`` lists the result table, then the model
            table, each with its column names. On failure ``status`` is 500,
            ``error_msg`` holds the error, and the created tables are dropped.
        """
        result = {
            "status": 0,
            "error_msg": "success",
            "result": {
                "input_params": {
                    "source_table": source_table,
                    "out_table_model": out_table_model,
                    "out_table_result": out_table_result,
                    "feature_cols": feature_cols,
                    "ground_truth": ground_truth,
                    "grouping_cols": grouping_cols
                },
                "output_params": [
                ]
            }
        }

        def splitCols(text):
            # "a, b,,c" -> ["a", "b", "c"]; None or blank -> [].
            if text is None:
                return []
            return [col.strip() for col in text.split(",") if col.strip()]

        featureCols = splitCols(feature_cols)
        if not featureCols:
            return logError("feature_cols must name at least one column", 500, result)
        grouping = []
        if grouping_cols is not None and grouping_cols.strip().upper() != "NULL":
            grouping = splitCols(grouping_cols)

        flag, msg = linearRegressionTrain(source_table, out_table_model, ground_truth, featureCols, grouping)
        if not flag:
            plpy.execute("DROP TABLE IF EXISTS {}".format(out_table_model))
            plpy.execute("DROP TABLE IF EXISTS {}_summary".format(out_table_model))
            return logError(msg, 500, result)
        flag, msg = linearRegressionPredict(source_table, out_table_model, out_table_result, featureCols, grouping)
        if not flag:
            plpy.execute("DROP TABLE IF EXISTS {}".format(out_table_model))
            plpy.execute("DROP TABLE IF EXISTS {}_summary".format(out_table_model))
            plpy.execute("DROP TABLE IF EXISTS {}".format(out_table_result))
            return logError(msg, 500, result)

        # list(): keep the column names JSON-serialisable even where
        # dict.keys() returns a view object rather than a list.
        modelCols = list(getTableMetaInfo(out_table_model).keys())
        # The summary table is not reported as an output of this pipeline.
        plpy.execute("DROP TABLE IF EXISTS {}_summary".format(out_table_model))
        resultCols = list(getTableMetaInfo(out_table_result).keys())

        result["result"]["output_params"].append({
            "out_table_name": out_table_result,
            "output_cols": resultCols
        })
        result["result"]["output_params"].append({
            "out_table_name": out_table_model,
            "output_cols": modelCols
        })
        return json.dumps(result)

    if __name__ == "__main__":
        return begin(source_table, out_table_model, out_table_result, ground_truth, feature_cols, grouping_cols)
   
$BODY$
  -- VOLATILE is required: the body creates and drops tables.
  LANGUAGE plpythonu VOLATILE
  -- 100 is PostgreSQL's default cost estimate for non-C functions.
  COST 100